import os
import shutil

import numpy as np
import torch


class AverageMeter:
    """Running tracker of the latest value and a count-weighted mean.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor; top-k is taken along dim 1 and dim 0 is
            presumably the batch (assumed shape (batch, num_classes) —
            confirm against callers).
        target: tensor of class indices, one per sample; ``target.size(0)``
            is used as the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element float tensor per k, holding
        ``100 * (#samples whose target is among the top-k) / batch_size``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk highest scores per sample, sorted descending,
        # then transposed so row r holds every sample's r-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (not view): `correct` comes from a transposed, and
            # therefore non-contiguous, tensor, on which view(-1) can fail.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def init_logfile(filename: str, text: str):
    """Initialise a log file.

    Creates or truncates ``filename`` and writes ``text`` followed by a
    newline (typically a header line).
    """
    # Context manager guarantees the handle is closed even if write() raises.
    with open(filename, 'w') as f:
        f.write(text + "\n")

def log(filename: str, text: str):
    """Append ``text`` plus a newline to the log file ``filename``.

    The file is created if it does not already exist.
    """
    # Context manager guarantees the handle is closed even if write() raises.
    with open(filename, 'a') as f:
        f.write(text + "\n")


def requires_grad_(model: torch.nn.Module, requires_grad: bool) -> None:
    """Enable or disable gradient computation for every parameter of ``model``.

    Args:
        model: module whose ``parameters()`` are updated in place.
        requires_grad: value passed to each parameter's ``requires_grad_``.
    """
    flag = requires_grad
    params = model.parameters()
    for weight in params:
        weight.requires_grad_(flag)


def copy_code(outdir, src_dir="./code_sigma_random"):
    """Copy every ``.py`` file under ``src_dir`` into ``outdir``.

    This stores the complete script alongside each experiment. The
    directory structure is mirrored under
    ``outdir/<basename of src_dir>/``; with the default ``src_dir`` the
    layout is ``outdir/code_sigma_random/...``, as before.

    Args:
        outdir: destination root; missing directories (including outdir
            itself) are created.
        src_dir: directory scanned recursively for ``.py`` files. It is
            relative to the current working directory unless absolute.
    """
    src_root = os.path.normpath(src_dir)
    # basename of the absolute path so "." or ".." still yields a real name
    # and the copy cannot escape outdir.
    dest_root = os.path.join(outdir, os.path.basename(os.path.abspath(src_dir)))

    # Collect first, copy second: if outdir lies inside src_dir, copying
    # during the walk could make os.walk descend into the fresh copies.
    code = []
    for root, _, files in os.walk(src_root, topdown=True):
        for f in files:
            if f.endswith('.py'):
                code.append((root, f))

    for root, f in code:
        codedir = os.path.join(dest_root, os.path.relpath(root, src_root))
        # makedirs rather than mkdir: a parent directory that holds no .py
        # files is never created by an earlier iteration, so mkdir would
        # fail on its subdirectories (and on a missing outdir).
        os.makedirs(codedir, exist_ok=True)
        shutil.copy2(os.path.join(root, f), os.path.join(codedir, f))
    print("Code copied to '{}'".format(outdir))

def read_result_max_radius(path, iteration=100, max_iter=100, interval=1):
    """Pick the max-radius row in each consecutive group of result rows.

    Reading rules:
        - The first line of ``path`` is treated as a header and skipped.
        - The remaining tab-separated rows are split into consecutive groups
          of ``iteration`` rows; a trailing partial group is ignored.
        - Within a group, rows at offsets ``0, interval, 2*interval, ...`` are
          inspected (at most ``iteration // interval`` of them). Inspection
          stops at the first offset ``>= max_iter``.

    Selection rules:
        - Among the inspected rows, the one with the largest radius (column
          ``-4``) wins.
        - Ties go to the later row.
        - The running best starts at 0, so a row with a negative (or NaN)
          radius never wins.

    Returns:
        ``(sigmas, statuses)``, two lists with one entry per group.
        ``sigmas[g]`` is column ``-1`` of the winning row as a float, or
        ``0`` if no row won. ``statuses[g]`` is True iff the winning row's
        column ``-3`` equals ``"1"``.
    """
    with open(path, "r") as handle:
        rows = handle.readlines()[1:]  # drop the header line

    sigmas = []
    statuses = []
    steps_per_group = iteration // interval
    for group in range(len(rows) // iteration):
        base = group * iteration
        best_radius, best_sigma, best_correct = 0, 0, False
        for step in range(steps_per_group):
            offset = step * interval
            if offset >= max_iter:
                break
            fields = rows[base + offset].split("\t")
            radius = float(fields[-4])
            if radius >= best_radius:
                best_radius = radius
                best_sigma = float(fields[-1])
                best_correct = fields[-3] == "1"
        sigmas.append(best_sigma)
        statuses.append(best_correct)

    return sigmas, statuses


